#!/usr/bin/env bash
# Author: Noel Chalmers

# set -x #echo on

# #################################################
# helper functions
# #################################################
#######################################
# Print the usage summary of this launcher.
# Arguments: none
# Outputs:   one line per help entry on stdout
#######################################
function display_help()
{
  # One argument per output line; quoting keeps the column padding intact.
  printf '%s\n' \
    "rocHPL run helper script" \
    "./run_rochpl " \
    "    [-P]    Specific MPI grid size: the number of         " \
    "            rows in MPI grid.                             " \
    "    [-Q]    Specific MPI grid size: the number of         " \
    "            columns in MPI grid.                          " \
    "    [-p]    Specific node-local MPI grid size: the number " \
    "            of rows in node-local MPI grid. Must evenly   " \
    "            divide P.                                     " \
    "    [-q]    Specific node-local MPI grid size: the number " \
    "            of columns in node-local MPI grid. Must evenly" \
    "            divide Q.                                     " \
    "    [-N]    Specific matrix size: the number of           " \
    "            rows/columns in global matrix.                " \
    "    [--NB]  Specific panel size: the number of            " \
    "            rows/columns in panels.                       " \
    "    [-f]    Specific split fraction: the percentage to    " \
    "            split the trailing submatrix.                 " \
    "    [-i]    Input file. When set, all other command       " \
    "            line parameters are ignored, and problem      " \
    "            parameters are read from input file.          " \
    "    [-h|--help] prints this help message                  " \
    "    [--version] Print rocHPL version number.              "
}

#######################################
# Verify that the script runs on a supported Linux distribution.
# Globals:   ID (read) - distribution identifier; must be defined before
#            calling
# Outputs:   a diagnostic on stderr when the check fails
# Returns:   0 when ${ID} is supported; otherwise exits the script with code 2
#######################################
supported_distro( )
{
  # ${ID+foo} expands to "foo" only when ID is set (even to an empty value);
  # it is quoted so the test never degrades to the one-argument form of `[`.
  if [[ -z "${ID+foo}" ]]; then
    printf 'supported_distro(): $ID must be set\n' >&2
    exit 2
  fi

  case "${ID}" in
    ubuntu|centos|rhel|fedora|sles)
        return 0
        ;;
    *)  printf 'This script is currently supported on Ubuntu, CentOS, RHEL, Fedora and SLES\n' >&2
        exit 2
        ;;
  esac
}

# #################################################
# Pre-requisites check
# #################################################
# Exit code 0: all is well
# Exit code 1: problems with getopt
# Exit code 2: problems with supported platforms

# The getopt command is needed for the argument parsing further down
if ! type getopt > /dev/null; then
  echo "This script uses getopt to parse arguments; try installing the util-linux package";
  exit 1
fi

# The os-release file describes the system; bail out early when it is missing
if [[ ! -e "/etc/os-release" ]]; then
  echo "This script depends on the /etc/os-release file"
  exit 2
fi
source /etc/os-release

# Exits the script if an unsupported distro is detected
supported_distro

# #################################################
# global variables
# #################################################
# Grab options from CMake config (the @...@ tokens are CMake placeholders)
rochpl_bin=@CMAKE_INSTALL_PREFIX@/bin/rochpl
rocm_dir=@ROCM_PATH@
rocblas_dir=@ROCBLAS_LIB_PATH@
blas_dir=@HPL_BLAS_DIR@

# Problem defaults; each can be overridden on the command line.
# P, Q  - rows / columns of the MPI grid
# p, q  - rows / columns of the node-local grid; a value below 1 means
#         "not specified" and is resolved later in the script
# N, NB - global matrix size and panel size
# frac  - split fraction for the trailing submatrix
P=1
Q=1
p=-1
q=-1
N=45312
NB=384
frac=0.6

# Input file name, and flags recording whether an input file (-i) or a
# command-line problem size (-N/--NB) was requested
filename=HPL.dat
inputfile=false
cmdrun=false

# Prepend the library directories. The previous value is appended only when it
# is non-empty: a trailing ':' would add an empty search-path entry, which the
# dynamic loader treats as the current directory.
export LD_LIBRARY_PATH="${rocblas_dir}:${blas_dir}:${rocm_dir}/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# When true, each rank is also given cores from the other columns of its row
oversubscribe=true

# #################################################
# Parameter parsing
# #################################################

# A modern getopt (one that can handle whitespace and long parameters)
# answers `getopt -T` with exit status 4; anything else is rejected.
getopt -T
if [[ $? -ne 4 ]]; then
  echo "Need a new version of getopt"
  exit 1
fi

# Normalize the command line; the assignment carries getopt's exit status
if ! GETOPT_PARSE=$(getopt --name "${0}" --longoptions NB:,help,version, --options hP:Q:p:q:N:i:f: -- "$@"); then
  echo "getopt invocation failed; could not parse the command line";
  exit 1
fi

# Replace the positional parameters with getopt's normalized form
eval set -- "${GETOPT_PARSE}"

# Consume options until the "--" terminator emitted by getopt
while true; do
  case "${1}" in
    -P)   P=${2}; shift 2 ;;
    -Q)   Q=${2}; shift 2 ;;
    -p)   p=${2}; shift 2 ;;
    -q)   q=${2}; shift 2 ;;
    -f)   frac=${2}; shift 2 ;;
    # -N and --NB additionally mark the run as command-line driven
    -N)   N=${2}; cmdrun=true; shift 2 ;;
    --NB) NB=${2}; cmdrun=true; shift 2 ;;
    # -i marks the run as input-file driven
    -i)   filename=${2}; inputfile=true; shift 2 ;;
    -h|--help)
        display_help
        exit 0 ;;
    --version)
        ${rochpl_bin} --version
        exit 0 ;;
    --)   shift; break ;;
    *)    echo "Unexpected command line parameter received; aborting";
        exit 1 ;;
  esac
done

# If neither an input file (-i) nor a command-line problem size (-N/--NB) was
# requested, default to running with the default input file
if [[ "${inputfile}" == false && "${cmdrun}" == false ]]; then
  inputfile=true
fi

# Both grid dimensions must be positive. Checking only the product P*Q is not
# enough: two negative values multiply to a positive process count
# (e.g. -P -2 -Q -2 gives 4).
if [[ "$P" -lt 1 || "$Q" -lt 1 ]]; then
  echo "Invalid MPI grid parameters; aborting";
  exit 1
fi

# Total number of processes in the P x Q grid
np=$(($P*$Q))

#######################################
# Now figure out the CPU core mappings
#######################################

# Get local process numbering from the environment set up by the MPI launcher.
# Sets globalRank/globalSize (rank and size over all processes) and rank/size
# (rank and number of ranks on this node).
# nounset is relaxed while reading launcher variables that may be absent and
# is switched on again afterwards.
set +u
if [[ -n ${OMPI_COMM_WORLD_LOCAL_RANK+x} ]]; then
  # Open MPI
  globalRank=$OMPI_COMM_WORLD_RANK
  globalSize=$OMPI_COMM_WORLD_SIZE
  rank=$OMPI_COMM_WORLD_LOCAL_RANK
  size=$OMPI_COMM_WORLD_LOCAL_SIZE
elif [[ -n ${SLURM_LOCALID+x} ]]; then
  # Slurm
  globalRank=$SLURM_PROCID
  globalSize=$SLURM_NTASKS
  rank=$SLURM_LOCALID
  #Slurm can return a string like "2(x2),1". Get the first number
  size=$(printf '%s\n' "${SLURM_TASKS_PER_NODE}" | sed -r 's/^([^.]+).*$/\1/; s/^[^0-9]*([0-9]+).*$/\1/')
elif [[ "${np:-1}" -eq 1 ]]; then
  # No recognized launcher, but only one process is requested (np is computed
  # earlier in this script): treat this process as the single rank.
  globalRank=0
  globalSize=1
  rank=0
  size=1
else
  # Previously this case left rank/size unset and the script died later with
  # a bare "unbound variable" message.
  echo "Unable to determine the node-local MPI rank: neither Open MPI (OMPI_COMM_WORLD_LOCAL_RANK) nor Slurm (SLURM_LOCALID) variables are set; aborting" >&2
  exit 1
fi
set -u

# Report that no node-local grid can be formed and stop the script.
abort_node_local_grid() {
  echo "Invalid MPI grid parameters; Unable to form node-local grid; aborting";
  exit 1
}

# Determine the node-local grid size (p x q); values below 1 mean "not given"
if [[ "$p" -lt 1 && "$q" -lt 1 ]]; then
  # no node-local grid was specified: default q to the smaller of Q and the
  # local size; p is then derived from q by the branch below
  q=$(( (Q<=size) ? Q : size))
fi

if [[ "$p" -lt 1 ]]; then
  # q is known (specified or defaulted): it must divide the local size
  if [[ $((size % q)) -gt 0 ]]; then
    abort_node_local_grid
  fi
  p=$(( size/q ))
elif [[ "$q" -lt 1 ]]; then
  # only p was specified: it must divide the local size
  if [[ $((size % p)) -gt 0 ]]; then
    abort_node_local_grid
  fi
  q=$(( size/p ))
elif [[ $size -ne $((p*q)) ]]; then
  # both p and q were specified: together they must cover the local size
  abort_node_local_grid
fi

# Check that the P rows are evenly divided among nodes
if [[ $((P % p)) -gt 0 ]]; then
  echo "Invalid MPI grid parameters; Must have the same number of P rows on every node; aborting";
  exit 1
fi

# Check that the Q columns are evenly divided among nodes
if [[ $((Q % q)) -gt 0 ]]; then
  echo "Invalid MPI grid parameters; Must have the same number of Q columns on every node; aborting";
  exit 1
fi

# Count the physical cores on the node from lscpu: 4th field of the "Core(s)"
# line times 2nd field of the "Socket" line
num_cpu_cores=$(lscpu | awk '/Core\(s\)/ {print $4}')
num_cpu_sockets=$(lscpu | awk '/Socket/ {print $2}')
total_cpu_cores=$(( ${num_cpu_cores} * ${num_cpu_sockets} ))

# Ranks in different process rows take distinct chunks of cores; each row's
# chunk is further divided among the columns
row_stride=$(( total_cpu_cores / p ))
col_stride=$(( row_stride / q ))

# Node-local grid coordinates of this rank (ranks are in column-major order)
myp=$(( rank % p ))
myq=$(( rank / p ))

# GPUs are selected in row-major order on the node
mygpu=$(( myp * q + myq ))

# Try to detect special Bard-peak core mapping. HPL_PLATFORM, when set (even to
# an empty value), takes precedence; otherwise the DMI product name is read.
platform=${HPL_PLATFORM-$(cat /sys/class/dmi/id/product_name)}

# Build the CPU core binding of this rank. Both branches produce:
#   omp_places      - comma-separated list of single-core places "{c}", with
#                     the rank's root core first
#   omp_num_threads - thread count, increased alongside the cores added
#   places          - compact "{first-last}" form of the same cores; it is
#                     echoed in the "Node Binding" message further down
if [[ "$platform" == "BardPeak" || "$platform" == "HPE_CRAY_EX235A" ]]; then
  # Special core mapping for BardPeak

  # Debug
  # if [[ $globalRank == 0 ]]; then
  #   echo "BardPeak platform detected"
  # fi

  # Sanity check: the table below has 8 entries, one per local rank at most
  if [[ $size -gt 8 ]]; then
    echo "Unsupported number of ranks on BardPeak platform; aborting";
    exit 1
  fi

  # First core of the 8-core group used for each GPU index:
  # GCD0 cores="48-55"
  # GCD1 cores="56-63"
  # GCD2 cores="16-23"
  # GCD3 cores="24-31"
  # GCD4 cores="0-7"
  # GCD5 cores="8-15"
  # GCD6 cores="32-39"
  # GCD7 cores="40-47"

  root_cores=(48 56 16 24 0 8 32 40)
  root_core=${root_cores[mygpu]}

  # First omp place is the root core
  omp_places="{$root_core}"

  # First assign the CCD: the 7 cores following the root core
  for i in $(seq $((root_core+1)) $((root_core+8-1)))
  do
    omp_places+=",{$i}"
  done
  omp_num_threads=8

  places="{$root_core-$((root_core+7))}"

  # Loop through unassigned CCDs: table entries mygpu+size, mygpu+2*size, ...
  # up to index 7. With 8 local ranks the sequence is empty.
  for c in $(seq $((mygpu+size)) $size 7)
  do
    iroot_core=${root_cores[c]}
    # All 8 cores of that group are added, including its first core
    for i in $(seq $((iroot_core)) $((iroot_core+8-1)))
    do
      omp_places+=",{$i}"
    done
    omp_num_threads=$((omp_num_threads+8))
    places+=",{$iroot_core-$((iroot_core+7))}"
  done

  if [[ "${oversubscribe}" == true ]]; then
    # Add cores from different columns, without their root cores
    for j in $(seq 0  $((q-1)))
    do
      # Skip this rank's own column
      if [[ "$j" == "$myq" ]]; then
        continue
      fi
      # jj walks the table in steps of the local size (0, size, 2*size, ... <= 7)
      for jj in $(seq 0 $size 7)
      do
        # Table index for column j of this rank's row, shifted by jj
        # NOTE(review): nothing here bounds q_gpu to the 8 entries of
        # root_cores; confirm for local sizes that do not divide 8
        q_gpu=$((jj+j+myp*q))
        q_core=$((root_cores[q_gpu]))
        # Indices below size get offset 1, so the first core of that group is
        # left out; groups at index >= size are taken whole
        offset=$(( (q_gpu>=size) ? 0 : 1))
        for i in $(seq $((q_core+offset)) $((q_core+8-1)))
        do
          omp_places+=",{$i}"
        done
        omp_num_threads=$((omp_num_threads+8-offset))
        places+=",{$((q_core+offset))-$((q_core+7))}"
      done
    done
  fi

else
  # Default core mapping: row myp starts at myp*row_stride, and column myq
  # starts myq*col_stride cores into that row
  root_core=$((myp*row_stride + myq*col_stride))

  omp_num_threads=${col_stride}
  # First omp place is the root core
  omp_places="{$root_core}"

  # Make contiguous chunk of cores (to maximize L1/L2 locality)
  for i in $(seq $((root_core+1)) $((root_core+col_stride-1)))
  do
    omp_places+=",{$i}"
  done

  # A chunk of a single core is written without a range
  if [[ $col_stride -gt 1 ]]; then
    places="{$root_core-$((root_core+col_stride-1))}"
  else
    places="{$root_core}"
  fi

  if [[ "${oversubscribe}" == true ]]; then
    # Add cores from different columns, without their root cores
    for j in $(seq 0 $((q-1)))
    do
      # Skip this rank's own column
      if [[ "$j" == "$myq" ]]; then
        continue
      fi
      # First core of column j's chunk in this rank's row; it is left out and
      # only the remaining col_stride-1 cores are added
      q_core=$((myp*row_stride + j*col_stride))
      for i in $(seq $((q_core+1)) $((q_core+col_stride-1)))
      do
        omp_places+=",{$i}"
      done
      omp_num_threads=$((omp_num_threads+col_stride-1))

      # Range form needs at least two remaining cores; a single remaining
      # core is listed alone, and nothing is added when col_stride is 1
      if [[ $col_stride -gt 2 ]]; then
        places+=",{$((q_core+1))-$((q_core+col_stride-1))}"
      elif [[ $col_stride -gt 1 ]]; then
        places+=",{$((q_core+1))}"
      fi

    done
  fi
fi

# Hand the computed binding to the OpenMP runtime through its environment
OMP_NUM_THREADS=${omp_num_threads}
OMP_PLACES=${omp_places}
OMP_PROC_BIND=true
export OMP_NUM_THREADS OMP_PLACES OMP_PROC_BIND

# Only processes whose global rank is below the node-local size report their
# binding
if [[ $globalRank -lt $size ]]; then
  echo "Node Binding: Process $rank [(p,q)=($myp,$myq)] CPU Cores: $omp_num_threads - $places"
fi

# Assemble the rochpl command line as an array so that every value is passed
# as exactly one argument (for example an input file path containing spaces),
# instead of being re-split from a flat string.
rochpl_args=(-P "${P}" -Q "${Q}" -p "${p}" -q "${q}" -f "${frac}")
if [[ "${inputfile}" == true ]]; then
  # problem parameters are read from the input file
  rochpl_args+=(-i "${filename}")
else
  # problem parameters come from the command line
  rochpl_args+=(-N "${N}" -NB "${NB}")
fi

#run
"${rochpl_bin}" "${rochpl_args[@]}"
